# evaluating claude model on BBH

import anthropic
import json

import numpy as np

from tqdm import tqdm
from tenacity import retry, stop_after_attempt, wait_chain, wait_fixed

# parse arguments
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--api_key", type=str, default="sk")
parser.add_argument("--model_index", type=str, default="claude-v1")
parser.add_argument("--task", type=str, default="all", choices=["all", "multiple_choice", "free_form"])
args = parser.parse_args()

MULTIPLE_CHOICE_TASKS = [
    "temporal_sequences",
    "disambiguation_qa",
    "date_understanding",
    "tracking_shuffled_objects_three_objects",
    "penguins_in_a_table",
    "geometric_shapes",
    "snarks",
    "ruin_names",
    "tracking_shuffled_objects_seven_objects",
    "tracking_shuffled_objects_five_objects",
    "logical_deduction_three_objects",
    "hyperbaton",
    "logical_deduction_five_objects",
    "logical_deduction_seven_objects",
    "movie_recommendation",
    "salient_translation_error_detection",
    "reasoning_about_colored_objects",
]
FREE_FORM_TASKS = [
    "multistep_arithmetic_two",
    "navigate",
    "dyck_languages",
    "word_sorting",
    "sports_understanding",
    "boolean_expressions",
    "object_counting",
    "formal_fallacies",
    "causal_judgement",
    "web_of_lies",
]


@retry(wait=wait_chain(*[wait_fixed(3) for i in range(3)] + [wait_fixed(5) for i in range(2)] + [wait_fixed(10)]))
def completion_with_backoff(model_index, messages):
    api_key = ""

    client = anthropic.Client(api_key)
    response = client.completion(
        prompt=f"{anthropic.HUMAN_PROMPT} {messages}{anthropic.AI_PROMPT}",
        stop_sequences=[anthropic.HUMAN_PROMPT],
        model=model_index,
        max_tokens_to_sample=1000,
    )
    print(response["completion"])
    return response["completion"]


def extract_ans(ans, mode):
    ans_line = ans.split("answer is ")
    # Expect to see 'answer is'. If not return whole string
    if len(ans_line) == 1:
        return ans
    else:
        ans = ans_line[-1].strip()

    if mode == "multiple_choice":
        options = [
            "(A)",
            "(B)",
            "(C)",
            "(D)",
            "(E)",
            "(F)",
            "(G)",
            "(H)",
            "(I)",
            "(J)",
            "(K)",
            "(L)",
            "(M)",
            "(N)",
            "(O)",
            "(P)",
            "(Q)",
            "(R)",
            "(S)",
            "(T)",
            "(U)",
            "(V)",
            "(W)",
            "(X)",
            "(Y)",
            "(Z)",
        ]
        for option in options:
            if option in ans:
                ans = option[1]
                break
        return ans
    elif mode == "free_form":
        if ans[-1] == ".":
            ans = ans[:-1]
        return ans


def run_tasks(tasks, mode, model_index="claude_instant_v1.0"):
    for task in tasks:
        print("Testing %s ..." % task)
        acc = 0
        task_data = json.load(open("data/%s.json" % task))
        task_prompt = open("lib_prompt/%s.txt" % task, "r").read()
        print_first = True
        with open("outputs/test_claude_instant_v1.0_%s.txt" % task, "w") as fd:
            for q_ in tqdm(task_data["examples"]):
                q = "\n\nQ: " + q_["input"]

                prompt_q = task_prompt + q + "\nA: Let's think step by step."

                if print_first:
                    print("First prompt: ")
                    print(prompt_q)
                    print_first = False

                response = completion_with_backoff(model_index, prompt_q)

                ans_model = response
                ans_ = extract_ans(ans_model, mode)

                if mode == "multiple_choice":
                    a = q_["target"][1]
                elif mode == "free_form":
                    a = q_["target"]

                fd.write("%s\nA_model:\n%s\nA_target:\n%s\n\n" % (q, ans_model, a))

                if ans_ == a:
                    acc += 1
            print("%s acc %.4f" % (task, acc / len(task_data["examples"])))
            fd.write("%s acc %.4f" % (task, acc / len(task_data["examples"])))


def main(args, multiple_choice_tasks=MULTIPLE_CHOICE_TASKS, free_form_tasks=FREE_FORM_TASKS):
    api_key = args.api_key
    model_index = args.model_index
    run_multiple_choice = args.task == "all" or args.task == "multiple_choice"
    run_free_form = args.task == "all" or args.task == "free_form"

    if run_multiple_choice:
        run_tasks(multiple_choice_tasks, mode="multiple_choice", model_index=model_index)
    if run_free_form:
        run_tasks(free_form_tasks, mode="free_form", model_index=model_index)
    return


if __name__ == "__main__":
    main(args)
